"""Nodule analysis and diagnosis: seeing whether a nodule is cancer
1. Generate nodule candidates
  1.1 Segmentation: predict if a given pixel is part of a nodule. This will be
      done per 2D slice, and every 2D result will be stacked to form a 3D array
      of voxels containing nodule candidate predictions.
  1.2 Grouping: group the voxels into nodule candidates by applying a threshold
      to the predictions, and then grouping connected regions of flagged voxels.
  1.3 Constructing sample tuples: each identified nodule candidate will be used
      to construct a sample tuple for classification.
2. Classify nodules and malignancy
  2.1 Nodule classification: each nodule candidate from segmentation and
      grouping will be classified as either nodule or non-nodule.
  2.2 ROC/AUC metrics: before classification, define some new metrics for
      examining the performance of classification models, as well as establish 
      a baseline metric against which to compare our malignancy classifiers.
  2.3 Fine-tuning the malignancy model: define a model specifically for
      classifying benign and malignant nodules, train it by fine-tuning, and
      see how it performs. 
3. End-to-end detection
   Answer the question: Are there malignant nodules present in the lungs?
  3.1 IRC: segment CT to get nodule candidate samples to classify.
  3.2 Determine the nodules: perform nodule classification on the candidates to
      determine whether they should be fed into the malignancy classifier.
  3.3 Determine malignancy: perform malignancy classification on the nodules.
"""

import os
import glob
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim
import scipy.ndimage as ndimage
import scipy.ndimage.morphology as morphology
import cls.net

from torch.utils.data import DataLoader
from seg.dataset.luna_2d_seg import Luna2dSegmentationDataset
from seg.net.unet_wrapper import UNetWrapper
from cls.dataset.cache import getCt
from utils.candidate import (
    getCandidateInfoList,
    getCandidateInfoDict,
    CandidateInfoTuple,
)
from cls.dataset.luna import LunaDataset
from utils.enumerate import enumerateWithEstimate
from utils.coordinate import irc2xyz
from config import DATA_PATH
from utils.logconf import logging

# Module-level logger; DEBUG level so per-series diagnostics are visible.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)


def print_confusion(label, confusions, do_mal):
    """Pretty-print a 3x4 confusion matrix as an aligned text table.

    Rows are the truth classes (Non-Nodules, Benign, Malignant); columns
    are the detection outcomes produced by ``match_and_score``.

    Args:
        label: Heading printed above the table (e.g. a series UID).
        confusions: 3x4 integer array; NOT modified (a copy is taken when
            columns need to be collapsed).
        do_mal: If True, show separate benign/malignant prediction columns;
            otherwise merge them into a single "Pred. Nodule" column.
    """
    row_labels = ["Non-Nodules", "Benign", "Malignant"]

    if do_mal:
        col_labels = [
            "",
            "Complete Miss",
            "Filtered Out",
            "Pred. Benign",
            "Pred. Malignant",
        ]
    else:
        col_labels = ["", "Complete Miss", "Filtered Out", "Pred. Nodule"]
        # Collapse the benign + malignant prediction columns into one.
        # Work on a copy: the caller may still be accumulating this matrix
        # into a running total, so we must not mutate it in place.
        confusions = confusions.copy()
        confusions[:, -2] += confusions[:, -1]
        confusions = confusions[:, :-1]
    cell_width = 16
    f = "{:>" + str(cell_width) + "}"
    print(label)
    print(" | ".join(f.format(s) for s in col_labels))
    for row_ndx, (row_label, row) in enumerate(zip(row_labels, confusions)):
        cells = [row_label] + list(row)
        if row_ndx == 0:
            # A non-nodule cannot be "completely missed"; blank that cell.
            cells[1] = ""
        print(" | ".join(f.format(c) for c in cells))


def match_and_score(detections, truth, threshold=0.5, threshold_mal=0.5):
    """Match detections against annotated nodules and build a confusion matrix.

    Args:
        detections: Iterable of (prob_nodule, prob_mal, center_xyz, center_irc)
            tuples as produced by classification.
        truth: Candidate annotations; nodules are those with ``isNodule_bool``.
        threshold: Nodule-probability cutoff; detections below it count as
            "filtered out" by the classifier.
        threshold_mal: Malignancy-probability cutoff separating predicted
            benign from predicted malignant.

    Returns:
        3x4 int32 confusion matrix.
        Rows: Truth: Non-Nodules, Benign, Malignant.
        Cols: Not Detected, Detected by Seg, Detected as Benign,
        Detected as Malignant.

    If one true nodule matches multiple detections, the "highest" detection
    is considered. If one detection matches several true nodule annotations,
    it counts for all of them.
    """
    true_nodules = [c for c in truth if c.isNodule_bool]
    truth_diams = np.array([c.diameter_mm for c in true_nodules])
    truth_xyz = np.array([c.center_xyz for c in true_nodules])

    detected_xyz = np.array([n[2] for n in detections])
    # detection classes will contain
    # 1 -> detected by seg but filtered by cls
    # 2 -> detected as benign nodule (or nodule if no malignancy model is used)
    # 3 -> detected as malignant nodule (if applicable)
    # BUGFIX: the malignancy probability must be compared against
    # threshold_mal (it was previously compared against threshold, leaving
    # the threshold_mal parameter unused).
    detected_classes = np.array(
        [
            1 if d[0] < threshold else (2 if d[1] < threshold_mal else 3)
            for d in detections
        ]
    )

    confusion = np.zeros((3, 4), dtype=np.int32)
    if len(detected_xyz) == 0:
        # Nothing detected at all: every true nodule is a complete miss.
        for tn in true_nodules:
            confusion[2 if tn.isMal_bool else 1, 0] += 1
    elif len(truth_xyz) == 0:
        # No ground truth: every detection is a false positive.
        for dc in detected_classes:
            confusion[0, dc] += 1
    else:
        # A detection matches a true nodule when its center lies within
        # 0.7 nodule-diameters of the annotated center.
        normalized_dists = (
            np.linalg.norm(truth_xyz[:, None] - detected_xyz[None], ord=2, axis=-1)
            / truth_diams[:, None]
        )
        matches = normalized_dists < 0.7
        unmatched_detections = np.ones(len(detections), dtype=np.bool_)
        matched_true_nodules = np.zeros(len(true_nodules), dtype=np.int32)
        for i_tn, i_detection in zip(*matches.nonzero()):
            # Keep the "highest" detection class per true nodule.
            matched_true_nodules[i_tn] = max(
                matched_true_nodules[i_tn], detected_classes[i_detection]
            )
            unmatched_detections[i_detection] = False

        for ud, dc in zip(unmatched_detections, detected_classes):
            if ud:
                confusion[0, dc] += 1
        for tn, dc in zip(true_nodules, matched_true_nodules):
            confusion[2 if tn.isMal_bool else 1, dc] += 1
    return confusion


class NoduleAnalysisApp:
    """End-to-end nodule analysis over one CT series or the validation set.

    Per series, the pipeline is:
      1. ``segmentCt``: per-slice segmentation -> boolean voxel mask
      2. ``groupSegmentationOutput``: connected components -> candidate tuples
      3. ``classifyCandidates``: nodule (and optionally malignancy)
         probabilities per candidate
    Results are compared against ground-truth annotations (when available)
    via ``match_and_score`` and printed as confusion matrices.
    """

    def __init__(self, cli_args):
        # Parsed argparse namespace; missing model paths are filled in below.
        self.cli_args = cli_args
        # self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')

        # Exactly one mode must be selected: a single series_uid XOR
        # running over the whole validation set.
        if not (bool(self.cli_args.series_uid) ^ self.cli_args.run_validation):
            raise Exception(
                "One and only one of series_uid and --run-validation should be given"
            )

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        # Fall back to the newest "*.best.state" checkpoint on disk when no
        # explicit path was supplied on the command line.
        if not self.cli_args.segmentation_path:
            self.cli_args.segmentation_path = self.initModelPath("seg")

        if not self.cli_args.classification_path:
            self.cli_args.classification_path = self.initModelPath("cls")

        self.seg_model, self.cls_model, self.malignancy_model = self.initModels()

    def initModelPath(self, type_str):
        """Return the path of the newest saved checkpoint of the given type.

        Globs for ``<type_str>_*_*.best.state`` under DATA_PATH/models and
        relies on lexicographic sorting of the (timestamped) filenames, so
        the last entry is treated as the most recent checkpoint.

        Raises:
            IndexError: If no matching checkpoint file exists.
        """
        local_path = os.path.join(
            DATA_PATH,
            "models",
            type_str,
            type_str + "_{}_{}.{}.state".format("*", "*", "best"),
        )

        file_list = glob.glob(local_path)
        file_list.sort()

        try:
            # Last entry after sorting == most recent checkpoint.
            return file_list[-1]
        except IndexError:
            # Nothing matched: log the pattern we searched for, then re-raise.
            log.debug([local_path, file_list])
            raise

    def initModels(self):
        """Load the segmentation, classification, and (optional) malignancy
        models from their saved state dicts, put them in eval mode, and move
        them to the selected device.

        Returns:
            (seg_model, cls_model, malignancy_model); malignancy_model is
            None when no --malignancy-path was given.
        """
        log.debug(self.cli_args.segmentation_path)
        seg_dict = torch.load(self.cli_args.segmentation_path)

        # in_channels=7 corresponds to the focus slice plus 3 context slices
        # on each side (matches contextSlices_count=3 in initSegmentationDl).
        seg_model = UNetWrapper(
            in_channels=7,
            n_classes=1,
            depth=3,
            wf=4,
            padding=True,
            batch_norm=True,
            up_mode="upconv",
        )

        seg_model.load_state_dict(seg_dict["model_state"])
        seg_model.eval()

        log.debug(self.cli_args.classification_path)
        cls_dict = torch.load(self.cli_args.classification_path)

        # The classifier class is looked up by name inside cls.net, so
        # --cls-model can select alternative architectures.
        model_cls = getattr(cls.net, self.cli_args.cls_model)
        cls_model = model_cls()
        cls_model.load_state_dict(cls_dict["model_state"])
        cls_model.eval()

        if self.use_cuda:
            if torch.cuda.device_count() > 1:
                seg_model = nn.DataParallel(seg_model)
                cls_model = nn.DataParallel(cls_model)

            seg_model.to(self.device)
            cls_model.to(self.device)

        if self.cli_args.malignancy_path:
            model_cls = getattr(cls.net, self.cli_args.malignancy_model)
            malignancy_model = model_cls()
            malignancy_dict = torch.load(self.cli_args.malignancy_path)
            malignancy_model.load_state_dict(malignancy_dict["model_state"])
            malignancy_model.eval()
            if self.use_cuda:
                malignancy_model.to(self.device)
        else:
            # No malignancy model: classifyCandidates substitutes zeros.
            malignancy_model = None
        return seg_model, cls_model, malignancy_model

    def initSegmentationDl(self, series_uid):
        """Build a DataLoader that yields every slice of one full CT scan
        (fullCt_bool=True) for the segmentation model.
        """
        seg_ds = Luna2dSegmentationDataset(
            contextSlices_count=3,
            series_uid=series_uid,
            fullCt_bool=True,
        )
        # Scale the batch size by the number of GPUs so each device gets
        # cli_args.batch_size samples.
        seg_dl = DataLoader(
            seg_ds,
            batch_size=self.cli_args.batch_size
            * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return seg_dl

    def initClassificationDl(self, candidateInfo_list):
        """Build a DataLoader over the given candidate list for the
        nodule/malignancy classifiers.
        """
        cls_ds = LunaDataset(
            sortby_str="series_uid",
            candidateInfo_list=candidateInfo_list,
        )
        # Same per-GPU batch-size scaling as initSegmentationDl.
        cls_dl = DataLoader(
            cls_ds,
            batch_size=self.cli_args.batch_size
            * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return cls_dl

    def main(self):
        """Run the full pipeline over the requested series.

        Builds the list of series to process (a single UID, the validation
        set, or validation + training with --include-train), runs
        segmentation -> grouping -> classification per series, and prints a
        per-series and a total confusion matrix where ground truth exists.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        # Series UIDs belonging to the validation split.
        val_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=True,
        )
        val_set = set(
            candidateInfo_tup.series_uid
            for candidateInfo_tup in val_ds.candidateInfo_list
        )
        # NOTE(review): positive_set is computed but not used below in this
        # file; presumably kept for logResults — confirm before removing.
        positive_set = set(
            candidateInfo_tup.series_uid
            for candidateInfo_tup in getCandidateInfoList()
            if candidateInfo_tup.isNodule_bool
        )

        if self.cli_args.series_uid:
            # Explicit comma-separated list of series on the command line.
            series_set = set(self.cli_args.series_uid.split(","))
        else:
            series_set = set(
                candidateInfo_tup.series_uid
                for candidateInfo_tup in getCandidateInfoList()
            )

        if self.cli_args.include_train:
            train_list = sorted(series_set - val_set)
        else:
            train_list = []
        val_list = sorted(series_set & val_set)

        candidateInfo_dict = getCandidateInfoDict()
        series_iter = enumerateWithEstimate(
            val_list + train_list,
            "Series",
        )
        all_confusion = np.zeros((3, 4), dtype=np.int32)
        for _, series_uid in series_iter:
            ct = getCt(series_uid)
            mask_a = self.segmentCt(ct, series_uid)  # segmentation

            candidateInfo_list = self.groupSegmentationOutput(
                series_uid, ct, mask_a
            )  # grouping
            classifications_list = self.classifyCandidates(
                ct, candidateInfo_list
            )  # classification

            # Single-series mode: print each candidate the classifier
            # considers a nodule (probability above 0.5).
            if not self.cli_args.run_validation:
                print(f"found nodule candidates in {series_uid}:")
                for prob, prob_mal, center_xyz, center_irc in classifications_list:
                    if prob > 0.5:
                        s = f"nodule prob {prob:.3f}, "
                        if self.malignancy_model:
                            s += f"malignancy prob {prob_mal:.3f}, "
                        s += f"center xyz {center_xyz}"
                        print(s)

            # If we have the ground truth data, we compute and print the confusion
            # matrix and also add the current results to the total
            if series_uid in candidateInfo_dict:
                one_confusion = match_and_score(
                    classifications_list, candidateInfo_dict[series_uid]
                )
                all_confusion += one_confusion
                print_confusion(
                    series_uid, one_confusion, self.malignancy_model is not None
                )

        print_confusion("Total", all_confusion, self.malignancy_model is not None)

    def classifyCandidates(self, ct, candidateInfo_list):
        """Run the nodule (and optional malignancy) classifier over all
        candidates of one CT.

        Returns:
            List of (prob_nodule, prob_mal, center_xyz, center_irc) tuples;
            prob_mal is 0.0 for every candidate when no malignancy model is
            loaded.
        """
        cls_dl = self.initClassificationDl(candidateInfo_list)
        classifications_list = []
        for batch_ndx, batch_tup in enumerate(cls_dl):
            input_t, _, _, series_list, center_list = batch_tup

            input_g = input_t.to(self.device)
            # Inference only: no gradients needed.
            with torch.no_grad():
                # NOTE(review): assumes each model returns (logits, probs)
                # with probs of shape (N, 2); column 1 is the positive class
                # — confirm against the cls.net model definitions.
                _, probability_nodule_g = self.cls_model(input_g)
                if self.malignancy_model is not None:
                    _, probability_mal_g = self.malignancy_model(input_g)
                else:
                    probability_mal_g = torch.zeros_like(probability_nodule_g)

            zip_iter = zip(
                center_list,
                probability_nodule_g[:, 1].tolist(),
                probability_mal_g[:, 1].tolist(),
            )
            for center_irc, prob_nodule, prob_mal in zip_iter:
                # Convert voxel (index, row, col) coordinates to patient xyz.
                center_xyz = irc2xyz(
                    center_irc,
                    direction_a=ct.direction_a,
                    origin_xyz=ct.origin_xyz,
                    vxSize_xyz=ct.vxSize_xyz,
                )
                cls_tup = (prob_nodule, prob_mal, center_xyz, center_irc)
                classifications_list.append(cls_tup)
        return classifications_list

    def segmentCt(self, ct, series_uid):
        """Perform segmentation on every slice of the entire CT scan.
        We build a dataset that loads a single CT, feed it slice by slice to
        the segmentation model, the output of which is an array of per-pixel
        probabilities that the given pixel is part of a nodule.

        Afterward, we threshold the predictions to get a binary array.

        Returns:
            Boolean numpy array with the same shape as ``ct.hu_a``.
        """
        with torch.no_grad():
            # One probability plane per CT slice, assembled into a 3D volume.
            output_a = np.zeros_like(ct.hu_a, dtype=np.float32)
            seg_dl = self.initSegmentationDl(series_uid)
            for input_t, _, _, slice_ndx_list in seg_dl:
                input_g = input_t.to(self.device)
                prediction_g = self.seg_model(input_g)

                for i, slice_ndx in enumerate(slice_ndx_list):
                    output_a[slice_ndx] = prediction_g[i].cpu().numpy()

            mask_a = output_a > 0.5
            # deletes one layer of boundary voxels and only keeps the inner
            # ones, this makes the flagged area smaller and causes very small
            # components to vanish.
            mask_a = morphology.binary_erosion(mask_a, iterations=1)

        return mask_a

    def groupSegmentationOutput(self, series_uid, ct, clean_a):
        """Group voxels into nodule candidates
        Use a simple connected-components algorithm for grouping our suspected
        nodule voxels into chunks to feed into classification.

        We get nodule locations for voxel-wise detection
        """
        # Take all nonzero pixels that share an edge with another nonzero pixel
        # and mark them as belonging to the same group.
        candidateLabel_a, candidate_count = ndimage.label(clean_a)
        # Gets the center of mass for each group as index, row, column coordinates
        centerIrc_list = ndimage.center_of_mass(
            ct.hu_a.clip(-1000, 1000) + 1001,  # the mass is non-negative
            labels=candidateLabel_a,
            index=np.arange(1, candidate_count + 1),
        )

        candidateInfo_list = []
        for i, center_irc in enumerate(centerIrc_list):
            # Converts the voxel coordinates to real patient coordinates
            center_xyz = irc2xyz(
                center_irc,
                ct.origin_xyz,
                ct.vxSize_xyz,
                ct.direction_a,
            )
            assert np.all(np.isfinite(center_irc)), repr(
                ["irc", center_irc, i, candidate_count]
            )
            assert np.all(np.isfinite(center_xyz)), repr(["xyz", center_xyz])
            # Candidates start with all flags False and diameter 0.0; the
            # classifiers decide nodule/malignancy status downstream.
            candidateInfo_tup = CandidateInfoTuple(
                False, False, False, 0.0, series_uid, center_xyz
            )
            candidateInfo_list.append(candidateInfo_tup)

        return candidateInfo_list

    def logResults(self, mode_str, filtered_list, series2diagnosis_dict, positive_set):
        """Log per-series diagnoses and summary metrics.

        Compares each series' predicted probability (threshold 0.5) against
        membership in ``positive_set`` and logs tp/tn/fp/fn percentages plus
        precision, recall, and F1.

        NOTE(review): this method is not called anywhere in this file;
        presumably invoked externally — confirm before removing.
        """
        count_dict = {"tp": 0, "tn": 0, "fp": 0, "fn": 0}
        for series_uid in filtered_list:
            # Series absent from the diagnosis dict default to probability
            # 0.0 with no location.
            probablity_float, center_irc = series2diagnosis_dict.get(
                series_uid, (0.0, None)
            )
            if center_irc is not None:
                center_irc = tuple(int(x.item()) for x in center_irc)
            positive_bool = series_uid in positive_set
            prediction_bool = probablity_float > 0.5
            correct_bool = positive_bool == prediction_bool

            if positive_bool and prediction_bool:
                count_dict["tp"] += 1
            if not positive_bool and not prediction_bool:
                count_dict["tn"] += 1
            if not positive_bool and prediction_bool:
                count_dict["fp"] += 1
            if positive_bool and not prediction_bool:
                count_dict["fn"] += 1

            log.info(
                "{} {} Label:{!r:5} Pred:{!r:5} Correct?:{!r:5} Value:{:.4f} {}".format(
                    mode_str,
                    series_uid,
                    positive_bool,
                    prediction_bool,
                    correct_bool,
                    probablity_float,
                    center_irc,
                )
            )

        total_count = sum(count_dict.values())
        # "or 1" guards against division by zero on empty input.
        percent_dict = {k: v / (total_count or 1) * 100 for k, v in count_dict.items()}

        # tp/tn/fp/fn entries are percentages; p/r/f1 are raw fractions
        # (the format strings below account for the difference).
        precision = percent_dict["p"] = count_dict["tp"] / (
            (count_dict["tp"] + count_dict["fp"]) or 1
        )
        recall = percent_dict["r"] = count_dict["tp"] / (
            (count_dict["tp"] + count_dict["fn"]) or 1
        )
        percent_dict["f1"] = 2 * (precision * recall) / ((precision + recall) or 1)

        log.info(
            mode_str
            + " tp:{tp:.1f}%, tn:{tn:.1f}%, fp:{fp:.1f}%, fn:{fn:.1f}%".format(
                **percent_dict,
            )
        )
        log.info(
            mode_str
            + " precision:{p:.3f}, recall:{r:.3f}, F1:{f1:.3f}".format(
                **percent_dict,
            )
        )


def parse_args(argv=None):
    """Parse command-line arguments for the nodule analysis app.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case argparse reads ``sys.argv[1:]`` — so existing no-argument
            callers are unaffected, while tests can pass an explicit list.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-size",
        default=4,
        type=int,
        help="Batch size to use for training",
    )
    parser.add_argument(
        "--num-workers",
        default=4,
        type=int,
        help="Number of worker processes for background data loading",
    )

    parser.add_argument(
        "--run-validation",
        action="store_true",
        default=False,
        help="Run over validation rather than a single CT.",
    )
    parser.add_argument(
        "--include-train",
        action="store_true",
        default=False,
        help="Include data that was in the training set. (default: validation data only)",
    )

    parser.add_argument(
        "--segmentation-path",
        nargs="?",
        default="data/models/seg/seg_2020-01-26_19.45.12_w4d3c1-bal_1_nodupe-label_pos-d1_fn8-adam.best.state",
        help="Path to the saved segmentation model",
    )

    parser.add_argument(
        "--cls-model",
        action="store",
        default="LunaModel",
        # Help text fixed: was the garbled "What to model class name to use".
        help="What model class name to use for the classifier.",
    )
    parser.add_argument(
        "--classification-path",
        nargs="?",
        default="data/models/cls/cls_2020-02-06_14.16.55_final-nodule-nonnodule.best.state",
        help="Path to the saved classification model",
    )

    parser.add_argument(
        "--malignancy-model",
        action="store",
        default="LunaModel",
        help="What model class name to use for the malignancy classifier.",
    )
    parser.add_argument(
        "--malignancy-path",
        nargs="?",
        default=None,
        help="Path to the saved malignancy classification model",
    )

    parser.add_argument(
        "series_uid",
        nargs="?",
        default=None,
        help="Series UID to use.",
    )

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: parse CLI args and run the full detection pipeline.
    NoduleAnalysisApp(parse_args()).main()
